from stack_mdp.gym_envs.envs.custom_envs import *
from stable_baselines3.common import logger
from games import *


def get_standard_matrix_env(config_dict):
    """Build a wrapped environment for one of the standard matrix games.

    The game is selected by ``config_dict["matrix_game_name"]``: ``"game_1"``
    picks ``MatrixGameOne``, ``"game_2"`` picks ``MatrixGameTwo``, and any
    other value falls back to ``MatrixGameThree``. The game's first agent is
    used as the leader and the remaining agents as followers.

    Args:
        config_dict: Configuration mapping. Keys read: ``log_folder``,
            ``matrix_game_name``, ``randomized``, ``randomization_type``,
            ``followers_algorithm``, ``tot_num_reward_steps``,
            ``tot_num_eq_steps``, ``frac_excluded_eq_steps``, ``critic_obs``.

    Returns:
        A ``BaseEnvSimpleMatrixGame``. It is optionally wrapped by a follower
        supervisor (``"MW"`` or ``"Qlearning"``; any other value leaves it
        unwrapped) and finally wrapped in ``StackMDPWrapper``.
    """
    output_formats = ["csv", "stdout", "tensorboard"]
    run_logger = logger.configure(folder=config_dict["log_folder"], format_strings=output_formats)

    game_name = config_dict["matrix_game_name"]
    # NOTE(review): unrecognised names silently fall back to MatrixGameThree.
    game_cls = (
        MatrixGameOne if game_name == "game_1"
        else MatrixGameTwo if game_name == "game_2"
        else MatrixGameThree
    )
    matrix_game = game_cls()

    agents = matrix_game.list_of_agents
    leader, followers = agents[0], agents[1:]

    env = BaseEnvSimpleMatrixGame(
        matrix_game,
        followers,
        leader,
        logger=run_logger,
        randomized=config_dict["randomized"],
        randomization_type=config_dict["randomization_type"],
    )

    followers_algorithm = config_dict["followers_algorithm"]
    if followers_algorithm == "MW":
        env = RLSupervisorMWFollowersWrapper(env)
    elif followers_algorithm == "Qlearning":
        env = RLSupervisorQFollowersWrapper(env)

    stack_keys = (
        "tot_num_reward_steps",
        "tot_num_eq_steps",
        "frac_excluded_eq_steps",
        "critic_obs",
    )
    return StackMDPWrapper(env, **{key: config_dict[key] for key in stack_keys})


def get_matrix_design_env(config_dict):
    """Build a wrapped environment for the matrix-design game.

    Args:
        config_dict: Configuration mapping. Keys read: ``log_folder``,
            ``followers_algorithm``, ``tot_num_reward_steps``,
            ``tot_num_eq_steps``, ``frac_excluded_eq_steps``, ``critic_obs``.

    Returns:
        A ``BaseEnvMatrixDesignGame`` built around a fresh ``MatrixDesignGame``.
        It is optionally wrapped by a follower supervisor (``"MW"`` or
        ``"Qlearning"``; any other value leaves it unwrapped) and finally
        wrapped in ``StackMDPWrapper``.
    """
    output_formats = ["csv", "stdout", "tensorboard"]
    run_logger = logger.configure(folder=config_dict["log_folder"], format_strings=output_formats)

    design_game = MatrixDesignGame()
    env = BaseEnvMatrixDesignGame(design_game, logger=run_logger)

    followers_algorithm = config_dict["followers_algorithm"]
    if followers_algorithm == "MW":
        env = RLSupervisorMWFollowersWrapper(env)
    elif followers_algorithm == "Qlearning":
        env = RLSupervisorQFollowersWrapper(env)

    stack_keys = (
        "tot_num_reward_steps",
        "tot_num_eq_steps",
        "frac_excluded_eq_steps",
        "critic_obs",
    )
    return StackMDPWrapper(env, **{key: config_dict[key] for key in stack_keys})


def get_simple_Bayesian_env(config_dict):
    """Build a wrapped Bayesian environment over the three diagonal matrix games.

    The leader and followers are taken from the first game in the list
    (``MatrixGameDiag1``). Its first agent is the leader and the rest are
    followers.

    Args:
        config_dict: Configuration mapping. Keys read: ``log_folder``,
            ``num_followers_messages``, ``followers_algorithm``,
            ``tot_num_reward_steps``, ``tot_num_eq_steps``,
            ``frac_excluded_eq_steps``, ``critic_obs``.

    Returns:
        A ``BaseSimpleMatrixBayesian``. It is optionally wrapped by a follower
        supervisor (``"MW"`` or ``"Qlearning"``; any other value leaves it
        unwrapped) and finally wrapped in ``StackMDPWrapper``.
    """
    output_formats = ["csv", "stdout", "tensorboard"]
    run_logger = logger.configure(folder=config_dict["log_folder"], format_strings=output_formats)

    # An earlier variant used [MatrixGameOne(), MatrixGameTwo()] here instead.
    candidate_games = [MatrixGameDiag1(), MatrixGameDiag2(), MatrixGameDiag3()]

    agents = candidate_games[0].list_of_agents
    leader, followers = agents[0], agents[1:]

    env = BaseSimpleMatrixBayesian(
        candidate_games,
        followers,
        leader,
        num_messages=config_dict["num_followers_messages"],
        logger=run_logger,
    )

    followers_algorithm = config_dict["followers_algorithm"]
    if followers_algorithm == "MW":
        env = RLSupervisorMWFollowersWrapper(env)
    elif followers_algorithm == "Qlearning":
        env = RLSupervisorQFollowersWrapper(env)

    stack_keys = (
        "tot_num_reward_steps",
        "tot_num_eq_steps",
        "frac_excluded_eq_steps",
        "critic_obs",
    )
    return StackMDPWrapper(env, **{key: config_dict[key] for key in stack_keys})


def get_mspm_env(config_dict):
    """Build a wrapped ``BaseMessageSPM`` environment.

    Args:
        config_dict: Configuration mapping. Keys read: ``log_folder``,
            ``num_followers_messages``, ``followers_algorithm``,
            ``tot_num_reward_steps``, ``tot_num_eq_steps``,
            ``frac_excluded_eq_steps``, ``critic_obs``.

    Returns:
        A ``BaseMessageSPM``. It is optionally wrapped by a follower supervisor
        (``"MW"`` or ``"Qlearning"``; any other value leaves it unwrapped) and
        finally wrapped in ``StackMDPWrapper``.
    """
    output_formats = ["csv", "stdout", "tensorboard"]
    run_logger = logger.configure(folder=config_dict["log_folder"], format_strings=output_formats)

    env = BaseMessageSPM(
        num_messages=config_dict["num_followers_messages"],
        logger=run_logger,
    )

    followers_algorithm = config_dict["followers_algorithm"]
    if followers_algorithm == "MW":
        env = RLSupervisorMWFollowersWrapper(env)
    elif followers_algorithm == "Qlearning":
        env = RLSupervisorQFollowersWrapper(env)

    stack_keys = (
        "tot_num_reward_steps",
        "tot_num_eq_steps",
        "frac_excluded_eq_steps",
        "critic_obs",
    )
    return StackMDPWrapper(env, **{key: config_dict[key] for key in stack_keys})